import argparse, time, os
from utils_data import *
from utils_algo import *
from models import *
import warnings
import scipy
 
# Suppress every Python warning for the remainder of the run.
warnings.filterwarnings("ignore")

# Seed NumPy, the torch CPU generator and all CUDA devices with one value so
# random draws in this run start from a fixed state. Note that
# cudnn.benchmark (enabled further down) can still introduce GPU-side
# nondeterminism.
_SEED = 0
np.random.seed(_SEED)
torch.manual_seed(_SEED)
torch.cuda.manual_seed_all(_SEED)

def _str2bool(value):
    """Convert a command-line truth value into a bool for argparse.

    ``type=bool`` cannot be used here: ``bool("False")`` is ``True`` because
    any non-empty string is truthy, so ``-noise False`` would enable noise.

    Accepts (case-insensitive) true/false, t/f, yes/no, y/n, 1/0, on/off.
    Raises argparse.ArgumentTypeError for anything else, which argparse
    reports as a usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(
    prog='Demo file for SC-Conf',
    usage='Demo with SC/Sub-Conf, NoRSC-Conf, and Naive Weighted method.',
    description='A simple demo file with Fashion-MNIST dataset.',
    epilog='end',
    add_help=True)

parser.add_argument('-lr', '--learning_rate', help='optimizer\'s learning rate', default=1e-4, type=float)
parser.add_argument('-bs', '--batch_size', help='batch_size of ordinary labels.', default=1000, type=int)
parser.add_argument('-e', '--epochs', help='number of epochs', type=int, default=100)
parser.add_argument('-wd', '--weight_decay', help='weight decay', default=1e-4, type=float)
# `-noise True` / `-noise False` are parsed correctly. A bare `-noise`
# (nargs='?' with const=True) also turns it on.
parser.add_argument('-noise', '--noise',
                    help='use noisy supervision (true/false); SC-Conf is skipped when enabled',
                    default=False, type=_str2bool, nargs='?', const=True)


args = parser.parse_args()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.enabled = True
# Lets cuDNN auto-tune kernels; this can make GPU runs nondeterministic
# even with fixed seeds.
torch.backends.cudnn.benchmark = True
# The two class labels passed to the data-preparation helper.
ordered_class = [7, 9]
print(ordered_class)

# Train a freshly initialised model for each method and print test accuracy
# before training and after every epoch. The outer loop runs a single trial.
# Widen range(1) to repeat the comparison; the data loaders are rebuilt on
# each trial.
for _trial in range(1):
    sc_train_loader, naive_train_loader, norsc_train_loader, test_loader = prepare_mnist_data(args.batch_size, ordered_class, args.noise)
    # (label printed in the log, loader yielding (images, weight) batches)
    methods = [
        ('SC-Conf', sc_train_loader),
        ('Weighted', naive_train_loader),
        ('NoRSC-Conf', norsc_train_loader),
    ]
    for method_name, train_loader in methods:
        # SC-Conf is not run when noisy supervision is requested.
        if args.noise and method_name == 'SC-Conf':
            continue
        print(method_name)
        model = mlp3_model(input_dim=28 * 28, hidden_dim=500, output_dim=10).to(device)
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay, lr=args.learning_rate)
        # Baseline before any training. This is the same accuracy metric as
        # the per-epoch line; it used to be mislabelled "Te loss".
        test_accuracy = accuracy_check(loader=test_loader, model=model)
        print('Epoch: 0.  Te acc: {}.'.format(test_accuracy))
        # `+ 1` so that `--epochs N` trains exactly N epochs (previously N - 1).
        for epoch in range(1, args.epochs + 1):
            # NOTE(review): accuracy_check may put the model in eval mode;
            # switch back to training mode before each epoch.
            model.train()
            for images, weight in train_loader:
                images, weight = images.to(device), weight.to(device)
                optimizer.zero_grad()
                # The model already lives on `device`, so its output does too.
                outputs = model(images)
                loss = SC_Conf_loss(output=outputs, weight=weight)
                loss.backward()
                optimizer.step()
            test_accuracy = accuracy_check(loader=test_loader, model=model)
            print('Epoch: {}.  Te acc: {}.'.format(epoch, test_accuracy))

